# run_all_models.py
# Requirements:
#   - models_list.txt (created by discover_models_flat.py; full model paths, one per line)
#   - modelloop.json (ComfyUI workflow exported in API format, using Tea_ImageCheckpointFromPath)
# Usage:
#   1. Start ComfyUI
#   2. Run:  python run_all_models.py


from pathlib import Path
import json, time, traceback, requests

# create positive.txt / negative.txt for overwriting workflow
POS = Path("positive.txt").read_text(encoding="utf-8").strip() if Path("positive.txt").exists() else None
NEG = Path("negative.txt").read_text(encoding="utf-8").strip() if Path("negative.txt").exists() else None



# --- configuration: adjust these if you named things differently ---
COMFY = "http://127.0.0.1:8188"                # base URL of the running ComfyUI server
TEMPLATE = "modelloop.json"                    # workflow saved in API format, using Tea_ImageCheckpointFromPath
MODELS_TXT = "models_list.txt"                 # from discover_models_flat.py: full model paths, one per line
LOADER_CLASS = "Tea_ImageCheckpointFromPath"   # class_type of the node whose "ckpt_path" input is rewritten
SAVE_CLASSES = {"SaveImage"}                   # class_types whose filename prefix is rewritten (stock SaveImage)

# -------- helpers --------
def nowstamp():
    """Return the current local time as a ``YYYYmmdd-HHMMSS`` string."""
    stamp_format = "%Y%m%d-%H%M%S"
    return time.strftime(stamp_format, time.localtime())

def deep_copy(d):
    """Return an independent copy of *d* by round-tripping it through JSON.

    Only JSON-serialisable data survives; non-string dict keys come back as
    strings, as with any JSON round-trip.
    """
    serialized = json.dumps(d)
    return json.loads(serialized)

def read_lines(path: Path):
    """Return the non-blank, non-comment lines of a UTF-8 text file.

    Each line is stripped of surrounding whitespace; lines that are empty or
    start with ``#`` (after stripping) are dropped.
    """
    kept = []
    for raw in path.read_text(encoding="utf-8").splitlines():
        line = raw.strip()
        if line and not line.startswith("#"):
            kept.append(line)
    return kept

def load_prompt(path: Path) -> dict:
    """Parse the UTF-8 JSON workflow file at *path* and return the decoded object."""
    raw_json = path.read_text(encoding="utf-8")
    return json.loads(raw_json)

def _nodes_iter(prompt: dict):
    return prompt["nodes"].items() if isinstance(prompt.get("nodes"), dict) else prompt.items()

def set_filename_prefix(prompt: dict, prefix: str) -> bool:
    """Write *prefix* as the output filename prefix of every save node.

    Nodes whose ``class_type`` is in SAVE_CLASSES are updated:
      * if ``inputs`` has a ``filename_prefix`` key, that input is set;
      * otherwise, if the node has ``widgets_values``, its first entry is set
        (a one-element list is created when it is empty);
      * independently, an ``inputs["subfolder"]`` key, if present, receives
        the part of *prefix* before the first ``/``.

    Returns True if a filename prefix was written on at least one node
    (subfolder-only updates do not count).
    """
    touched = False
    for _node_id, node in _nodes_iter(prompt):
        if node.get("class_type") not in SAVE_CLASSES:
            continue

        has_inputs = "inputs" in node
        if has_inputs and "filename_prefix" in node["inputs"]:
            # prefix stored as a named input
            node["inputs"]["filename_prefix"] = prefix
            touched = True
        elif "widgets_values" in node:
            # prefix stored as the first widget value
            widgets = node["widgets_values"]
            if widgets:
                widgets[0] = prefix
            else:
                node["widgets_values"] = [prefix]
            touched = True

        # optional subfolder input (e.g. advanced save nodes)
        if has_inputs and "subfolder" in node["inputs"]:
            node["inputs"]["subfolder"] = prefix.partition("/")[0]
    return touched

def set_prompts(prompt: dict, positive: str | None = None, negative: str | None = None) -> bool:
    # find KSampler (assumes one; extend if needed)
    ks_id, ks = next(((nid, n) for nid, n in _nodes_iter(prompt) if n.get("class_type") == "KSampler"), (None, None))
    if not ks: return False

    pos_ref = ks["inputs"].get("positive")  # like ["6", 0]
    neg_ref = ks["inputs"].get("negative")  # like ["7", 0]

    changed = False
    # helper to set text of a node id if it is CLIPTextEncode
    def _set_text(node_id: str, new_text: str | None):
        nonlocal changed
        if not new_text: return
        node = prompt["nodes"][node_id] if "nodes" in prompt else prompt[node_id]
        if node.get("class_type") == "CLIPTextEncode":
            node.setdefault("inputs", {})["text"] = new_text
            changed = True

    if isinstance(pos_ref, list) and len(pos_ref) >= 1:
        _set_text(str(pos_ref[0]), positive)
    if isinstance(neg_ref, list) and len(neg_ref) >= 1:
        _set_text(str(neg_ref[0]), negative)

    return changed

def set_ckpt_path(prompt: dict, model_path: str) -> bool:
    """Set ``inputs["ckpt_path"]`` to *model_path* on every LOADER_CLASS node.

    An ``inputs`` dict is created on a loader node that lacks one.
    Returns True if at least one loader node was found (and updated).
    """
    hit = False
    for _, node in _nodes_iter(prompt):
        if node.get("class_type") != LOADER_CLASS:
            continue
        inputs = node.setdefault("inputs", {})
        inputs["ckpt_path"] = model_path
        hit = True
    return hit

def enqueue(prompt: dict, timeout: float = 30.0) -> str:
    """Submit *prompt* to the ComfyUI ``/prompt`` endpoint and return its id.

    Args:
        prompt: workflow dict sent as the ``"prompt"`` field of the JSON body.
        timeout: seconds to wait for the HTTP response. Without a timeout a
            hung server would block the script indefinitely.

    Returns:
        The ``prompt_id`` from the response JSON, or "" if it is absent.

    Raises:
        requests.HTTPError: on a 4xx/5xx response. The message includes the
            start of the response body, which is where the server is expected
            to describe why a workflow was rejected.
        requests.Timeout: if no response arrives within *timeout* seconds.
    """
    r = requests.post(f"{COMFY}/prompt", json={"prompt": prompt}, timeout=timeout)
    if not r.ok:
        raise requests.HTTPError(
            f"{r.status_code} from {COMFY}/prompt: {r.text[:2000]}", response=r
        )
    return r.json().get("prompt_id", "")

def wait_until_idle(poll_s=1.2, timeout_s=None):
    """Block until the ComfyUI queue has nothing running and nothing pending.

    Polls ``GET {COMFY}/queue``. ComfyUI reports the queue there as two lists,
    ``queue_running`` and ``queue_pending``, so those keys are checked.
    ``running``/``pending`` are still accepted as a fallback.

    Args:
        poll_s: seconds to sleep between polls.
        timeout_s: optional overall limit in seconds. None (the default)
            waits indefinitely.

    Raises:
        requests.HTTPError: if the queue endpoint returns an error status.
        TimeoutError: if *timeout_s* elapses while work is still queued.
    """
    deadline = None if timeout_s is None else time.monotonic() + timeout_s
    while True:
        # Per-request timeout so a hung server cannot block a single poll forever.
        r = requests.get(f"{COMFY}/queue", timeout=30)
        r.raise_for_status()
        q = r.json()
        # Values may be lists of queue entries or counts; both are falsy when empty.
        running = q.get("queue_running", q.get("running", 0))
        pending = q.get("queue_pending", q.get("pending", 0))
        if not running and not pending:
            return
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"ComfyUI queue still busy after {timeout_s}s")
        time.sleep(poll_s)

# -------- main --------
def main():
    """Queue the template workflow once per model listed in MODELS_TXT.

    For each model path, the loop does the following:
      * deep-copies the template;
      * points the loader node at the model;
      * sets a per-run save prefix and the prompt text;
      * enqueues the copy on ComfyUI;
      * waits for the queue to drain before moving to the next model.

    The paths that ran and the paths that failed are written to timestamped
    log files next to this script. This also happens when the loop is
    interrupted (e.g. Ctrl+C), so progress from a long run is not lost.

    Raises:
        FileNotFoundError: if the template or the model list is missing.
    """
    here = Path(__file__).parent
    tpl_path = here / TEMPLATE
    list_path = here / MODELS_TXT

    if not tpl_path.exists():
        raise FileNotFoundError(f"Template not found: {tpl_path}")
    if not list_path.exists():
        raise FileNotFoundError(f"Model list not found: {list_path}")

    models = read_lines(list_path)
    if not models:
        print("! models_list.txt is empty.")
        return

    run = nowstamp()
    ran_log = here / f"models_ran_{run}.txt"
    fail_log = here / f"failed_models_{run}.txt"

    tpl = load_prompt(tpl_path)
    ran = []
    failed = []

    # Prompt text: the module-level POS / NEG (contents of positive.txt /
    # negative.txt, None when the file is absent) take precedence over these
    # defaults. Distinct local names are required: assigning to POS/NEG here
    # would make them locals, shadow the module-level values and silently
    # ignore the files. An empty file yields "", which set_prompts skips, so
    # the template's own text is kept for that side.
    positive = POS if POS is not None else (
        "a wonderful woman in boots and stockings, casual background, sweet smile"
    )
    negative = NEG if NEG is not None else "teen, child, anime, doll, overly smooth skin"

    print(f"Queuing {len(models)} models. Run folder: {run}")
    try:
        for m in models:
            model_path = Path(m)
            stem = model_path.stem
            prefix = f"{run}/{stem}__"  # run folder / per-model filename prefix
            print(f"\n=== {stem} ===")

            try:
                prompt = deep_copy(tpl)  # fresh copy so edits never leak between models
                if not set_ckpt_path(prompt, str(model_path)):
                    print("! Loader node not found in template; skipping.")
                    failed.append(m)
                    continue

                set_filename_prefix(prompt, prefix)
                set_prompts(prompt, positive=positive, negative=negative)

                pid = enqueue(prompt)
                print(f"queued prompt_id={pid}  prefix='{prefix}'")
                wait_until_idle()
                print("✓ done")
                ran.append(m)
            except requests.HTTPError as e:
                print(f"! HTTP error on {stem}: {e}")
                traceback.print_exc()
                failed.append(m)
            except Exception as e:
                print(f"! failed on {stem}: {e}")
                traceback.print_exc()
                failed.append(m)
    finally:
        # logs -- written even if the loop above was interrupted
        if ran:
            ran_log.write_text("\n".join(ran), encoding="utf-8")
            print(f"\nSaved ran list → {ran_log}")
        if failed:
            fail_log.write_text("\n".join(failed), encoding="utf-8")
            print(f"Some failed; saved list → {fail_log}")

# Run the batch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
